iT邦幫忙

2021 iThome 鐵人賽

DAY 15
1
自我挑戰組

資料分析及AI深度學習-簡單基礎實作系列 第 15

DAY15:玉山人工智慧挑戰賽-中文手寫字辨識(Pytorch 自訂義資料集)

  • 分享至 

  • xImage
  •  

資料擴增

  • 我們組的資料擴增這部分,因為第一次比賽,這個方法效果沒有到非常好,採取的是用mask的方式,讓圖檔多加一些遮蔽物,如下圖。詳細操作參考組員的分享(傳送門)
  • 增加完我們的圖片總數量約為19萬張。

Pytorch自定義資料集

  • 我們先定義一個alphabet,它代表的是我們的800個字的位置。
img_names = os.listdir(data_path)
source = img_names
alphabet = ''.join(source)

  • 因為我們到時候會將圖檔用PIL的Image讀取出來,所以先將圖檔和對應的label組成一個list。
# 讀取圖檔,並轉換大小為80*80,以及轉換成RGB
def img_loader(img_path):
    image = Image.open(img_path)
    img = image.resize((80, 80),Image.ANTIALIAS) #resize image with high-quality
    return img.convert('RGB')
# 將圖檔與label對應,丟入自定義的資料集內
def make_dataset(data_path, alphabet, num_class):
    samples = []
    for i in os.listdir(data_path):
        for j in os.listdir(data_path + '/' + i):
            img_path = data_path + '/' + i + '/' + j
            target_str = j.split('.')[0][-1]
            vec = [0] * 800
            vec[alphabet.find(target_str)] = 1
            target = vec
            samples.append((img_path, target))
    return samples

例如這個字是"不",由alphabet的位置可以看到alphabet[3]的位置是"不",故在alphabet[3]的位置為1,代表他的label,其餘位置皆為0。

  • torch.utils.data.Dataset,是一個自定義資料集的框架。

    • __ init __()

      • 負責做一個初始化的動作,我們先定義我們要的東西:
        1. data_path:我們要讀取資料集的路徑。
        2. num_class:我們要預測的種類數量(800類)
        3. transform:對於圖片是否進行處理,這裡設定None,不對讀取進來的圖片作處理。
        4. target_transform:對標籤做處理,這裡我們也都處理好了,不對標籤做處理,設定為None。
        5. alphabet:我們要預測的所有字,將它變成str。
        6. samples:是我們用make_data弄成的圖片與label對應的list。
      def __init__(self, data_path, num_class=800,transform=None,target_transform=None, alphabet=alphabet):
          super(Dataset, self).__init__()
          self.data_path = data_path
          self.num_class = num_class
          self.transform = transform
          self.target_transform = target_transform
          self.alphabet = alphabet
          self.samples = make_data.set(self.data_path, self.alphabet)
      
    • __ len __ ()

      • 返回list中的長度,也就是你的資料的筆數。
      def __len__(self):
          return len(self.samples)
      
    • __ getitem __ ()

      • 使資料集可以節省內存,資料集為dataset,而__ len __ ()返回的數字n,使的dataset[n]的圖片能被讀取,需要時才將圖片讀取,所以可以節省內存。返回值一個圖片樣本及標籤。
      def __getitem__(self, index):
          img_path, target = self.samples[index]
          img = img_loader(img_path)
          if self.transform is not None:
              img = self.transform(img)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return img, torch.Tensor(target) # 在torch裡面,array都要轉成Tensor型式
      
    • 完整程式碼

    class CaptchaData(Dataset):
      def __init__(self, data_path, num_class=800,
                   transform=None, target_transform=None, alphabet=alphabet):
          super(Dataset, self).__init__()
          self.data_path = data_path
          self.num_class = num_class
          self.transform = transform
          self.target_transform = target_transform
          self.alphabet = alphabet
          self.samples = make_dataset(self.data_path, self.alphabet)
    
      def __len__(self):
          return len(self.samples)
    
      def __getitem__(self, index):
          img_path, target = self.samples[index]
          img = img_loader(img_path)
          if self.transform is not None:
              img = self.transform(img)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return img, torch.Tensor(target)
    
  • torch.utils.data.DataLoader

    • Dataset設置好後,DataLoader可以依照batch_size讓我們取樣,非常方便。

      from torchvision.transforms import Compose, ToTensor
      from torch.utils.data import DataLoader
      transforms = Compose([ToTensor()])
      train_dataset = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\data_final_20210530\official_in_800',transform=transforms)
      train_data_loader = DataLoader(train_dataset, batch_size=1, num_workers=0,
                                         shuffle=True, drop_last=True)
      
      for (data,label) in train_data_loader:
          print((data,label))
      

      我把batch_size設定為1,他一次就只取出一組圖片樣本及標籤。

      • batch_size:一次要取多少樣本。
      • num_workers:數據加載的子進程數,0則為主要進程。(這裡有個小坑,我只要條不是0的時候,都會發生error,如果有知道的大神,幫我指點一下,小弟非常感謝。)
      • shuffle:若設定為True,指每個epoch取出樣本的順序都會不一樣。
      • drop_last:若設定為True,則全部樣本不能被batch_size整除時,最後一批會直接被刪除;若為False,則最後一批會較小。
  • 除了自定義資料集以外,還有可以torchvision.datasets.ImageFolder来處理資料集,用法會在於你分好類別,他的資料夾名稱就是他的label,而裡面圖片都屬於這個label。


今日小結

  • 深度學習有很多很好玩的地方,但也有很多的坑,debug我都要找很久XDD,重點是東西太多,絕對學不完,而且很吃硬體設備。有時候會覺得自己好笨,都學不會,但看久了發現懂一點了,就又有動力繼續往下學了,接觸深度學習的朋友們,我們一起繼續努力吧!

  • 小弟我是試著用自定義資料集來處理,原因只想練習以及可以更彈性的操作載入資料的動作。

  • 前面加載圖片時我們把transforms設置為None,現在我們丟模型訓練要對圖片做transforms,他可以增加圖片的多樣性,例如:旋轉、平移、變形等等,明天來跟大家分享torchvision很好用的套件transforms。


上一篇
DAY14:玉山人工智慧挑戰賽-中文手寫字辨識(OpenCV圖像處理)
下一篇
DAY16:Pytorch transforms(上)
系列文
資料分析及AI深度學習-簡單基礎實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言